[Feature] Add support for loading datasets from local Minari cache #3068
Conversation
🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3068.
Note: links to docs will display an error until the docs builds have completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
dataset_id = "cartpole/test-local-v1"
# Create dataset using Gym + DataCollector
Custom minari dataset creation from a gymnasium environment
That's really handy, I love it!
Do you think we have an opportunity here to reduce the number of datasets we download to test Minari, and use custom-built datasets instead, like you do in your test?
Absolutely. Instead of downloading 20 datasets from the Minari server, we could generate smaller, custom datasets from any gymnasium environment as part of our tests. This approach is especially valuable for D4RL datasets, which tend to be very large and can significantly slow down testing.
test/test_libs.py
Outdated
@@ -29,9 +29,12 @@
from sys import platform
from unittest import mock

import minari
This is a global import - we must avoid these at all costs.
Can you make it local?
Ok, I can fix it.
Do you want to give it a go? Otherwise I can do it, no worries.
Yes, no problem.
@vmoens Do you think we should replace testing with all Minari datasets by using smaller, custom datasets generated from gymnasium environments? Or should we still download and test with some datasets from the Minari server to ensure that the downloading functionality in MinariExperienceReplay works correctly?
Maybe just one defined (not random) dataset, to make sure?
Ok, I agree that testing with a single, well-defined dataset makes sense to ensure the download and loading functionality are covered. However, with most gym environments, creating custom datasets isn't always straightforward. For example, I ran into this error:
I'll keep investigating to see if there's a workaround or a more general approach that works across environments.
Ok, maybe we could proceed with this in its current state and refactor the tests later!
Ok, I have refactored the tests to use this custom-dataset approach. I am still fixing some minor errors and making sure it all works.
If you're ready, feel free to commit here directly.
LGTM!
This actually makes me think that we should have a similar pipeline where one creates a datacollector with some replay buffer and serializes it as a dataset later on.
It's mostly a matter of documenting it rather than creating the feature, everything already exists I think!
Ok, should I provide a tutorial documenting the pipeline from a DataCollector with a replay buffer to dataset serialization?
First let's fix this!
I already import minari and ale_py only inside the specific functions or test cases that require them, rather than at the global level.
Do you think moving all imports into _minari_init, which only performs the necessary imports if the dependencies are installed, would solve this?
No, imports must happen if and only if they're needed. Otherwise spawned multiprocessed jobs all import a ton of useless libs (if they're installed).
@@ -3341,6 +3341,39 @@ def test_d4rl_iteration(self, task, split_trajs):

_MINARI_DATASETS = []

MUJOCO_ENVIRONMENTS = [
Explicitly specifies the current set of Minari-supported datasets for integration with Gym environments.
The specified datasets are those used in MuJoCo and D4RL.
# Initialize with placeholder values for parametrization
# These will be replaced with actual dataset names when the first Minari test runs
_MINARI_DATASETS = [str(i) for i in range(20)]

def get_random_minigrid_datasets():
Only use Minari datasets from Minigrid. This is because the current version of Minari cannot serialize custom MissionSpace objects, which are used in most Minigrid environments.
This means you cannot create a custom dataset from a Minigrid environment directly; you will need to modify the MissionSpace first.
cleanup_needed = False

else:
    # Atari environment datasets
For Atari environments, select a random subset of environments.
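The selection step described above could look like the following sketch; the environment ids and subset size are illustrative, and the generator is seeded so CI runs stay reproducible:

```python
import random

# Illustrative pool of Atari environment ids (requires ale_py at use time)
ATARI_ENVIRONMENTS = [
    "ALE/Pong-v5",
    "ALE/Breakout-v5",
    "ALE/Seaquest-v5",
    "ALE/SpaceInvaders-v5",
]


def get_random_atari_envs(n=2, seed=0):
    # Seeded Random instance: the same subset is drawn on every run
    rng = random.Random(seed)
    return rng.sample(ATARI_ENVIRONMENTS, n)
```

Seeding keeps the test matrix bounded without making failures flaky across CI runs.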
test/test_libs.py
Outdated
@@ -29,11 +29,9 @@
from sys import platform
from unittest import mock
Removed global imports
Note: new dependencies were added for creating custom Minari datasets from gym environments:
- ale_py (for Atari)
- gymnasium_robotics (for D4RL & Mujoco)
test/test_libs.py
Outdated
range(
    len(MUJOCO_ENVIRONMENTS)
    + len(D4RL_ENVIRONMENTS)
    + len(get_random_minigrid_datasets())
These two functions import minari and are called at the module level (global workspace), which causes the error during CI.
I will refactor them so that the imports and calls only occur within the relevant test functions.
test/test_libs.py
Outdated
@@ -3488,12 +3488,7 @@ class TestMinari:
@pytest.mark.parametrize("split", [False, True])
@pytest.mark.parametrize(
    "dataset_idx",
range(
    len(MUJOCO_ENVIRONMENTS)
These decorators were evaluated in the global workspace and imported minari and ale_py internally.
Now I use a static upper bound, so no function that imports minari or ale_py is called at module level.
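The static-bound pattern described here can be sketched as follows; `_minari_init` is stubbed with a fixed list for illustration (in the real test suite it imports minari locally and returns the generated dataset ids):

```python
import pytest

# Static upper bound: evaluated at collection time with no optional imports
MAX_MINARI_DATASETS = 20


def _minari_init():
    # Stub: the real version imports minari inside the function body
    # and returns the ids of the custom-built datasets.
    return ["cartpole/test-local-v1"]


@pytest.mark.parametrize("dataset_idx", range(MAX_MINARI_DATASETS))
def test_minari_dataset(dataset_idx):
    datasets = _minari_init()
    # Indices beyond the actual dataset count are skipped at run time
    if dataset_idx >= len(datasets):
        pytest.skip("static bound exceeds the number of generated datasets")
    dataset_id = datasets[dataset_idx]
    assert dataset_id  # placeholder for the real checks
```

The trade-off is a handful of skipped parametrizations in exchange for an import-free collection phase.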
Cool! Looks like the conflicts with other CI runs are resolved. We have some long-lasting failures but I wouldn't pay too much attention to them.
Reduce number of fetched datasets to avoid 429 Error
@@ -21,3 +21,5 @@ dependencies:
  - hydra-core
  - minari[gcs,hdf5,hf]
  - gymnasium<1.0.0
  - ale-py
Added new dependencies
Description
This PR adds support for loading datasets directly from the local Minari cache in the `MinariExperienceReplay` class. Specifically, it introduces the `load_from_local_minari` argument, which, when set to `True`, instructs the class to load the dataset from the user's local Minari cache (typically at `~/.minari/datasets`) and skip any fetching from the Minari server (i.e., no remote download or overwrite will occur). After loading from the local cache, all subsequent preprocessing and loading steps continue as usual, ensuring the dataset is processed and made available correctly. This is especially useful for custom Minari datasets or when you want to avoid network access.

Documentation has been updated in the class docstring to clearly state the new behavior of the `load_from_local_minari` argument, including details about local cache prioritization and unchanged downstream preprocessing.

This PR also includes comprehensive test coverage for the new feature, confirming that datasets created and stored in the local Minari cache can be loaded, sampled, and validated for correctness using the new argument. The provided test (`test_local_minari_dataset_loading`) creates a custom dataset, loads it from the cache, verifies sample integrity, and cleans up afterwards.

Motivation and Context
Previously, `MinariExperienceReplay` required datasets to be downloaded via its own interface, which was incompatible with custom or preloaded Minari datasets; attempting to load these would result in a `FileNotFoundError`.

This change allows users to work with their own datasets, datasets created with `minari.DataCollector(...).create_dataset(...)`, or any dataset present in the local Minari cache, without requiring redundant downloads or manual metadata copying.

If the dataset is not found in the local cache, a `FileNotFoundError` is raised with a clear message.

Solves #3067